CREATE OR REPLACE FUNCTION "pipeline"."isolation_forest"("source_table" varchar, "out_table" varchar, "id_col" varchar, "feature_cols" varchar, "n_estimators" int4, "max_samples" float8, "contamination" float8)
  RETURNS "pg_catalog"."text" AS $BODY$
  import numpy as np
  from sklearn.ensemble import IsolationForest
  import json
  import time
  
  # Initialize the JSON-serializable return payload shared with the nested
  # helpers.  NOTE(review): the `global result` declaration in the __main__
  # block at the bottom makes this assignment bind at module scope under
  # Python 2 (plpythonu), which is what lets throwError/begin_alg mutate it
  # via their own `global result` statements — confirm before porting to
  # plpython3u, where assign-before-global is a SyntaxError.
  result = {
    "status": 0,                 # 0 = success, 500 = failure (set by throwError)
    "error_msg": "success",
    "result": {
      "input_params": {
        "out_table": out_table,
        "source_table": source_table,
        "id_col": id_col,
        "feature_cols": [],      # filled in by begin_alg after splitting
        "n_estimators": n_estimators,
        "max_samples": max_samples,
        "contamination": contamination
      },
      "output_params": []        # begin_alg appends {out_table_name, output_cols}
    }
  }

  # Build the SELECT pulling the id column plus the (comma-separated) feature
  # columns from the source table.
  # NOTE(review): identifiers are interpolated directly, not parameterized --
  # callers must supply trusted table/column names only.
  def get_data_sql(source_table, id_col, feature_cols):
    return "select {}, {} from {}".format(id_col, feature_cols, source_table)
  
  # Fetch {column_name: data_type} for a table from information_schema.
  # Accepts schema-qualified names ("schema.table"); the lookup itself is by
  # bare table name.  NOTE(review): the schema part is discarded, so a
  # same-named table in another schema could shadow the intended one --
  # consider also filtering on table_schema.
  def getTableMetaInfo(table_name):
    parts = table_name.split(".")
    # BUG FIX: the original tested `len(tmps) > 0` and indexed tmps[1],
    # which raised IndexError for any unqualified table name.
    if len(parts) > 1:
      table_name = parts[-1]
    # NOTE(review): name is interpolated directly -- trusted input only.
    sql_str = "select column_name, data_type from information_schema.columns where table_name = '%s'" %(table_name)
    rs = plpy.execute(sql_str)
    meta = {}
    for line in rs:
      meta[line['column_name']] = line['data_type']
    return meta

  # Create the output table (id, features..., label) and insert the scored rows.
  # `datas` rows are [id, feature values..., label]; `meta` maps column name ->
  # data type (from getTableMetaInfo).  Returns the table name and column list.
  # Raises on any SQL failure; the caller handles cleanup.
  def saveOutTable(datas, out_table, id_col, meta, features):
    string_types = ["text", "character varying", "varchar", "char", "date", "character",
        "timestamp", "time", "timestamp_with_timezone", "longvarchar", "longnvarchar", "nvarchar", "nchar"]
    if meta.get(id_col) in string_types:
      # BUG FIX: the original template "('{}', {})" had only two placeholders
      # but was formatted with three arguments, silently dropping the label
      # and producing an INSERT whose value count mismatched the column list.
      value_tpl = "('{}', {}, '{}')"
    else:
      value_tpl = "({}, {}, '{}')"
    table_name = out_table
    defBody = []
    for fea in features:
      defBody.append("{} {}".format(fea, meta.get(fea)))
    # NOTE(review): identifiers and values are interpolated directly (no
    # parameterization) -- trusted input only.
    sql_str = "CREATE TABLE %s (%s %s, %s, label varchar);" %(table_name, id_col, meta.get(id_col), ",".join(defBody))
    plpy.execute(sql_str)

    _values = []
    for item in datas:
      s = [str(x) for x in item[1:-1]]
      _values.append(value_tpl.format(item[0], ",".join(s), item[-1]))
    # Guard against an empty source table: "INSERT ... VALUES" with no tuples
    # is a syntax error.
    if _values:
      sql_str = "INSERT INTO %s (%s, %s, label) VALUES" %(table_name, id_col, ",".join(features))
      sql_str += ",".join(_values)
      plpy.execute(sql_str)
    tmpcols = [id_col]
    tmpcols.extend(features)
    tmpcols.append("label")
    return {"out_table_name": table_name, "cols": tmpcols}

  # Mark the shared result payload as failed: status 500 plus the stringified
  # exception as the error message.
  def throwError(e):
    global result
    result["error_msg"] = str(e)
    result["status"] = 500

  # Main driver: read source rows, fit an IsolationForest, label every row
  # (+1 inlier / -1 outlier), persist <id, features..., label> into out_table,
  # and return the JSON status payload.  Every failure path records the error
  # on `result` via throwError and returns the serialized payload.
  def begin_alg(source_table, out_table, id_col, feature_cols, n_estimators, max_samples, contamination):
    global result
    sql_str = get_data_sql(source_table, id_col, feature_cols)
    # Tolerate "a, b" as well as "a,b": stray whitespace previously broke the
    # per-row dict lookup and meta.get() in saveOutTable.
    features = [f.strip() for f in feature_cols.split(",")]
    result["result"]["input_params"]["feature_cols"] = features
    ids = []
    try:
      rv = plpy.execute(sql_str)
    except Exception as e:
      throwError(e)
      return json.dumps(result)
    datas = []
    # Build the row-major feature matrix (datas) and the [id, features...]
    # rows (ids) that later become the output-table tuples.
    try:
      for item in rv:
        line = []
        tmpData = [item[id_col]]
        for feature in features:
          # Feature names may arrive quoted ('"col"'); result keys are bare.
          line.append(float(item[feature.replace('"', "")]))
        tmpData.extend(line)
        datas.append(line)
        ids.append(tmpData)
    except Exception as e:
      # BUG FIX: conversion errors (e.g. float(None) for a NULL feature, or a
      # missing column key) previously escaped uncaught and aborted the whole
      # call instead of returning the JSON error like every other path.
      throwError(e)
      return json.dumps(result)
    datas = np.array(datas, dtype=np.float64)
    try:
      clf = IsolationForest(n_estimators = n_estimators, max_samples = max_samples, contamination = contamination)
      clf.fit(datas)
    except Exception as e:
      throwError(e)
      return json.dumps(result)
    res = clf.predict(datas).tolist()

    out_datas = []
    for i in range(len(ids)):
      line = ids[i]
      line.append(str(res[i]))
      out_datas.append(line)
    meta = getTableMetaInfo(source_table)
    try:
      saveRet = saveOutTable(out_datas, out_table, id_col, meta, features)
    except Exception as e:
      # Best-effort cleanup of a half-created output table before reporting.
      plpy.execute("DROP TABLE IF EXISTS {}".format(out_table))
      throwError(e)
      return json.dumps(result)
    result["result"]["output_params"].append({
        "out_table_name": saveRet["out_table_name"],
        "output_cols": saveRet["cols"]
    })
    return json.dumps(result)

  # Entry point.  PL/Python wraps this whole body in a generated function, so
  # `return` is legal here.
  if __name__ == '__main__':
    # NOTE(review): under Python 2 (plpythonu) this `global result` applies to
    # the entire wrapped body, making the top-level `result = {...}` a global
    # assignment so the helpers' `global result` see the same object.
    # plpython3u rejects assign-before-global -- confirm before porting.
    global result
    ret = begin_alg(source_table, out_table, id_col, feature_cols, n_estimators, max_samples, contamination)
    return ret
$BODY$
  LANGUAGE plpythonu VOLATILE
  COST 100